{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[1. , 2.1],\n",
       "        [2. , 1.1],\n",
       "        [1.3, 1. ],\n",
       "        [1. , 1. ],\n",
       "        [2. , 1. ]]),\n",
       " array([ 1,  1, -1, -1,  1]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def load_data():\n",
    "    \"\"\"Return the five-sample, two-feature toy dataset as ``(features, labels)``.\"\"\"\n",
    "    features = np.array([\n",
    "        [1.0, 2.1],\n",
    "        [2.0, 1.1],\n",
    "        [1.3, 1.0],\n",
    "        [1.0, 1.0],\n",
    "        [2.0, 1.0],\n",
    "    ])\n",
    "    labels = np.array([1, 1, -1, -1, 1])\n",
    "    return features, labels\n",
    "\n",
    "\n",
    "x, y = load_data()\n",
    "x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.2, 0.2, 0.2, 0.2, 0.2])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = len(x)\n",
    "\n",
    "# Per-sample weights: every sample starts out equally weighted (1 / N each).\n",
    "# The vector is sized from N so it always matches the number of samples in x.\n",
    "D = np.full(N, 1 / N)\n",
    "D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1. , 2.1]), -1)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Tree():\n",
    "    \"\"\"Decision stump that classifies a sample by thresholding one feature.\n",
    "\n",
    "    Args:\n",
    "        col: index of the feature that is compared against the threshold.\n",
    "        value: the threshold.\n",
    "        eq: '<' predicts 1 when the feature is strictly below the threshold;\n",
    "            '>' predicts 1 when the feature is greater than or equal to it.\n",
    "            In both modes every other sample is predicted as -1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if eq is neither '<' nor '>'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, col, value, eq):\n",
    "        # Reject unknown directions up front; otherwise __call__ would have no\n",
    "        # branch to take and the stump could never produce a valid prediction.\n",
    "        if eq not in ('<', '>'):\n",
    "            raise ValueError('eq must be one of < or >, got {!r}'.format(eq))\n",
    "        self.col = col\n",
    "        self.value = value\n",
    "        self.eq = eq\n",
    "        # Weight of this tree; starts at 1.\n",
    "        self.weight = 1\n",
    "\n",
    "    def __call__(self, xi):\n",
    "        \"\"\"Predict 1 or -1 for the sample xi by comparing xi[self.col] with self.value.\"\"\"\n",
    "        if self.eq == '<':\n",
    "            return 1 if xi[self.col] < self.value else -1\n",
    "\n",
    "        if self.eq == '>':\n",
    "            # '>' is the complement of '<', so equality falls on the positive side.\n",
    "            return 1 if xi[self.col] >= self.value else -1\n",
    "\n",
    "        # Only reachable if eq was reassigned after construction.\n",
    "        raise ValueError('eq must be one of < or >, got {!r}'.format(self.eq))\n",
    "\n",
    "\n",
    "tree = Tree(0, 1, '<')\n",
    "x[0], tree(x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6000000000000001"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(tree):\n",
    "    \"\"\"Return the weighted error of ``tree`` on the notebook-level data.\n",
    "\n",
    "    Walks the globals ``x`` (samples), ``y`` (labels) and ``D`` (weights) in\n",
    "    lockstep and adds up the weight of every sample whose prediction differs\n",
    "    from its label.\n",
    "    \"\"\"\n",
    "    weighted_error = 0\n",
    "    for sample, label, weight in zip(x, y, D):\n",
    "        missed = tree(sample) != label\n",
    "        if missed:\n",
    "            # Only misclassified samples contribute their weight.\n",
    "            weighted_error += weight\n",
    "    return weighted_error\n",
    "\n",
    "\n",
    "get_loss(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.Tree at 0x7fada8769ba8>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_tree():\n",
    "    \"\"\"Search exhaustively for the ``Tree`` with the lowest weighted loss.\n",
    "\n",
    "    For every column of the global ``x`` and for both comparison directions,\n",
    "    the threshold starts at (column min - 0.1) and grows in steps of 0.1 while\n",
    "    it stays below (column max + 0.1). Each candidate is scored with\n",
    "    ``get_loss``.\n",
    "\n",
    "    Returns:\n",
    "        The first ``Tree`` that reaches the minimum loss, or None when no\n",
    "        candidate improved on the initial infinite loss (for example when\n",
    "        ``x`` has no columns).\n",
    "    \"\"\"\n",
    "    min_loss = np.inf\n",
    "    # Initialised here so the final return is always bound.\n",
    "    min_tree = None\n",
    "\n",
    "    # Iterate over every feature column.\n",
    "    for col in range(x.shape[1]):\n",
    "\n",
    "        # The sweep range depends only on the column, so compute it once per\n",
    "        # column: from (column min - 0.1) up to (column max + 0.1).\n",
    "        col_min = x[:, col].min() - 0.1\n",
    "        col_max = x[:, col].max() + 0.1\n",
    "\n",
    "        # Try both comparison directions.\n",
    "        for eq in ['<', '>']:\n",
    "\n",
    "            value = col_min\n",
    "\n",
    "            # Sweep the threshold across the column's range.\n",
    "            while value < col_max:\n",
    "                tree = Tree(col, value, eq)\n",
    "                loss = get_loss(tree)\n",
    "\n",
    "                # Strict comparison: on ties the earliest candidate is kept.\n",
    "                if loss < min_loss:\n",
    "                    min_loss = loss\n",
    "                    min_tree = tree\n",
    "\n",
    "                value += 0.1\n",
    "\n",
    "    return min_tree\n",
    "\n",
    "\n",
    "tree = get_tree()\n",
    "tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.2"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_loss(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1 1\n",
      "1 1\n",
      "-1 -1\n",
      "-1 -1\n",
      "1 1\n"
     ]
    }
   ],
   "source": [
    "# Print each prediction of the selected tree next to its true label.\n",
    "for sample, label in zip(x, y):\n",
    "    print(tree(sample), label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
